#include <torch/extension.h>


// Declaration only; the definition lives outside this file and must keep this
// exact signature (return type, name, parameter types) to link.
//
// Presumably estimates max(|matrix|) by inspecting a random sample of tiles
// instead of the full tensor. NOTE(review): inferred from the name and
// parameters — confirm against the definition.
//   matrix      - input tensor (shape/dtype/device requirements not visible
//                 here — TODO confirm, e.g. 2-D / contiguous / CUDA?).
//   tile_count  - presumably the number of tiles sampled.
//   tile_length - presumably the number of elements per sampled tile.
//   seed        - seed for the sampling RNG; presumably the same seed gives
//                 the same sample (verify).
// Returns a tensor; presumably a single-element tensor holding the estimate
// (TODO confirm shape/dtype).
// See also sampled_max_abs_f, declared later in this file with the same
// signature.
torch::Tensor sampled_max_abs(torch::Tensor matrix, int tile_count, int tile_length, int seed);

// Declaration only; the definition lives outside this file and must keep this
// exact signature to link.
//
// Presumably searches `n_grid` candidate quantization scales (derived from
// `absmax`) and scores each on a random sample of tiles of `matrix`.
// NOTE(review): inferred from the name and parameters — confirm against the
// definition.
//   matrix      - input tensor (shape/dtype/device requirements not visible
//                 here — TODO confirm).
//   absmax      - presumably the max-abs value used to bound the scale grid,
//                 e.g. the result of sampled_max_abs (verify).
//   n_grid      - presumably the number of candidate scales evaluated.
//   tile_count  - presumably the number of tiles sampled.
//   tile_length - presumably the number of elements per sampled tile.
//   seed        - seed for the sampling RNG.
// Returns a pair of tensors whose meaning is not visible here — presumably
// (chosen scale, associated error) or similar. TODO confirm order and shapes.
// See also sampled_scale_grid_search_f, declared later in this file with the
// same signature.
std::tuple<torch::Tensor, torch::Tensor> sampled_scale_grid_search(
        torch::Tensor matrix, float absmax, int n_grid, int tile_count, int tile_length, int seed);

// Declaration only; the definition lives outside this file and must keep this
// exact signature to link.
//
// Presumably quantizes `input_tensor` using the given `scale` and returns the
// quantized result. NOTE(review): target dtype and rounding/clamping rules are
// not visible here; given grid_search_quant_int8 in this file this may be
// int8 — confirm against the definition.
//   input_tensor - tensor to quantize (shape/dtype requirements TODO confirm).
//   scale        - quantization scale (whether values are divided or
//                  multiplied by it is not visible here — verify).
// See also quantize_tensor_f, declared later in this file with the same
// signature.
torch::Tensor quantize_tensor(torch::Tensor input_tensor, float scale);

// Declaration only; the definition lives outside this file and must keep this
// exact signature to link.
//
// Presumably the top-level entry point. It picks an int8 quantization scale
// for `input_tensor` via a grid search over sampled data, and optionally
// quantizes the tensor with it. NOTE(review): inferred from the name and
// parameters — confirm against the definition.
//   input_tensor - tensor to analyse/quantize (shape/dtype TODO confirm).
//   n_grid       - presumably the number of candidate scales evaluated.
//   sampling     - presumably the fraction (or amount) of the tensor sampled
//                  for the search; valid range not visible here — verify.
//   seed         - seed for the sampling RNG.
//   do_quant     - presumably whether to also produce the quantized tensor
//                  (verify what the first tuple element holds when false).
// Returns a pair of tensors; presumably (quantized tensor, scale) or similar —
// TODO confirm order and contents.
// See also grid_search_quant_int8_f, declared later in this file with the same
// signature.
std::tuple<torch::Tensor, torch::Tensor> grid_search_quant_int8(
    torch::Tensor input_tensor, int n_grid, float sampling, int seed, bool do_quant);


// Declaration only; the definition lives outside this file and must keep this
// exact signature to link.
//
// `_f` counterpart of sampled_max_abs, declared earlier in this file with an
// identical signature. NOTE(review): what distinguishes the `_f` variant
// (float-specific? fused? fast?) is not visible here — confirm against the
// definition and document it.
//   matrix      - input tensor (shape/dtype/device requirements TODO confirm).
//   tile_count  - presumably the number of tiles sampled.
//   tile_length - presumably the number of elements per sampled tile.
//   seed        - seed for the sampling RNG.
torch::Tensor sampled_max_abs_f(torch::Tensor matrix, int tile_count, int tile_length, int seed);

// Declaration only; the definition lives outside this file and must keep this
// exact signature to link.
//
// `_f` counterpart of sampled_scale_grid_search, declared earlier in this file
// with an identical signature. NOTE(review): the meaning of the `_f` suffix is
// not visible here — confirm against the definition.
//   matrix      - input tensor (shape/dtype/device requirements TODO confirm).
//   absmax      - presumably the max-abs value bounding the scale grid.
//   n_grid      - presumably the number of candidate scales evaluated.
//   tile_count  - presumably the number of tiles sampled.
//   tile_length - presumably the number of elements per sampled tile.
//   seed        - seed for the sampling RNG.
// Returns a pair of tensors; contents/order not visible here — TODO confirm.
std::tuple<torch::Tensor, torch::Tensor> sampled_scale_grid_search_f(
        torch::Tensor matrix, float absmax, int n_grid, int tile_count, int tile_length, int seed);

// Declaration only; the definition lives outside this file and must keep this
// exact signature to link.
//
// `_f` counterpart of quantize_tensor, declared earlier in this file with an
// identical signature. NOTE(review): the meaning of the `_f` suffix and the
// output dtype are not visible here — confirm against the definition.
//   input_tensor - tensor to quantize (shape/dtype requirements TODO confirm).
//   scale        - quantization scale.
torch::Tensor quantize_tensor_f(torch::Tensor input_tensor, float scale);

// Declaration only; the definition lives outside this file and must keep this
// exact signature to link.
//
// `_f` counterpart of grid_search_quant_int8, declared earlier in this file
// with an identical signature. NOTE(review): the meaning of the `_f` suffix is
// not visible here — confirm against the definition.
//   input_tensor - tensor to analyse/quantize (shape/dtype TODO confirm).
//   n_grid       - presumably the number of candidate scales evaluated.
//   sampling     - presumably the fraction (or amount) of data sampled.
//   seed         - seed for the sampling RNG.
//   do_quant     - presumably whether to also produce the quantized tensor.
// Returns a pair of tensors; contents/order not visible here — TODO confirm.
std::tuple<torch::Tensor, torch::Tensor> grid_search_quant_int8_f(
    torch::Tensor input_tensor, int n_grid, float sampling, int seed, bool do_quant);
